Skip to content

Commit

Permalink
Merge pull request #6 from worldcoin/0xkitsune/mpc-coordinator
Browse files Browse the repository at this point in the history
feat(coordinator): Initial logic for coordinator
  • Loading branch information
0xKitsune authored Jan 31, 2024
2 parents 447fa51 + 333312f commit 1a64018
Show file tree
Hide file tree
Showing 22 changed files with 2,333 additions and 56 deletions.
703 changes: 696 additions & 7 deletions Cargo.lock

Large diffs are not rendered by default.

30 changes: 25 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,48 @@ version = "0.1.0"
edition = "2021"

[dependencies]
aws-config = "1.1.4"
aws-sdk-sqs = "1.12.0"
bytemuck = { version = "1.14.1", features = ["derive"] }
clap = { version = "4.4.18", features = ["derive", "env"] }
config = "0.13.4"
dotenv = "0.15.0"
criterion = "0.5.1"
dotenv = "0.15.0"
eyre = "0.6.11"
futures = "0.3.30"
hex = { version = "0.4.3", features = ["serde"] }
itertools = "0.12.0"
memmap = "0.7.0"
metrics = "0.21.1"
rand = "0.8.5"
rayon = "1.8.1"
serde = { version = "1.0.195", features = ["derive"] }
serde_json = "1.0.111"
sqlx = { version = "0.7.3", features = [
"runtime-tokio-native-tls",
"any",
"postgres",
"chrono",
] }
docker-db = { git = "https://github.com/Dzejkop/docker-db", rev = "ef9a4dfccd9cb9b4babeebfea0ba815e36afbafe" }
telemetry-batteries = { git = "https://github.com/worldcoin/telemetry-batteries.git", rev = "c6816624415ae194da5203a5161621a9e10ad3b0" }
tokio = { version = "1.35.1", features = ["macros"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
metrics = "0.21.1"
serde = { version = "1.0.195", features = ["derive"] }

[dev-dependencies]
float_eq = "1.0.1"
proptest = "1.4.0"

[[bin]]
name = "mpc-node"
path = "bin/mpc_node.rs"
name = "mpc-participant"
path = "bin/mpc_participant.rs"

[[bin]]
name = "mpc-coordinator"
path = "bin/mpc_coordinator.rs"

[[bench]]
name = "example"
harness = false

72 changes: 72 additions & 0 deletions bin/mpc_coordinator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use std::path::PathBuf;

use aws_config::BehaviorVersion;
use clap::Parser;
use mpc::config::CoordinatorConfig;
use mpc::coordinator::Coordinator;
use telemetry_batteries::metrics::batteries::StatsdBattery;
use telemetry_batteries::tracing::batteries::DatadogBattery;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;

pub const SERVICE_NAME: &str = "mpc-coordinator";

pub const METRICS_HOST: &str = "localhost";
pub const METRICS_PORT: u16 = 8125;
pub const METRICS_QUEUE_SIZE: usize = 5000;
pub const METRICS_BUFFER_SIZE: usize = 1024;
pub const METRICS_PREFIX: &str = "mpc-coordinator";

#[derive(Parser)]
#[clap(version)]
pub struct Args {
#[clap(short, long, env)]
telemetry: bool,

#[clap(short, long, env)]
config: Option<PathBuf>,
}

#[tokio::main]
async fn main() -> eyre::Result<()> {
dotenv::dotenv().ok();

let args = Args::parse();

if args.telemetry {
DatadogBattery::init(None, SERVICE_NAME, None, true);

StatsdBattery::init(
METRICS_HOST,
METRICS_PORT,
METRICS_QUEUE_SIZE,
METRICS_BUFFER_SIZE,
Some(METRICS_PREFIX),
)?;
} else {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer().pretty().compact())
.with(tracing_subscriber::EnvFilter::from_default_env())
.init();
}

let mut settings = config::Config::builder();

if let Some(path) = args.config {
settings = settings.add_source(config::File::from(path).required(true));
}

let settings = settings
.add_source(config::Environment::with_prefix("MPC").separator("__"))
.build()?;

let config = settings.try_deserialize::<CoordinatorConfig>()?;

let coordinator =
Coordinator::new(vec![], "template_queue_url", "distance_queue_url")
.await?;

coordinator.spawn().await?;

Ok(())
}
20 changes: 1 addition & 19 deletions bin/mpc_node.rs → bin/mpc_participant.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::path::PathBuf;

use clap::Parser;
use mpc::config::Config;
use telemetry_batteries::metrics::batteries::StatsdBattery;
use telemetry_batteries::tracing::batteries::DatadogBattery;
use tracing_subscriber::layer::SubscriberExt;
Expand Down Expand Up @@ -60,22 +59,5 @@ async fn main() -> eyre::Result<()> {
.add_source(config::Environment::with_prefix("MPC").separator("__"))
.build()?;

let config = settings.try_deserialize::<Config>()?;

let mut n = 0;

loop {
foo(&config, n).await;

n += 1;

tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
}
}

#[tracing::instrument(skip(config))]
async fn foo(config: &Config, n: usize) {
tracing::info!(n, test = config.test.test, "Foo");

metrics::gauge!("foo", n as f64);
Ok(())
}
4 changes: 4 additions & 0 deletions migrations/coordinator/001_init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CREATE TABLE masks (
id BIGINT PRIMARY KEY,
mask BYTEA NOT NULL
);
4 changes: 4 additions & 0 deletions migrations/participant/001_init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CREATE TABLE shares (
id BIGINT PRIMARY KEY,
share BYTEA NOT NULL
);
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2024-01-26"
channel = "nightly-2024-01-25"
components = ["rustc-dev", "rustc", "cargo", "rustfmt", "clippy"]
100 changes: 100 additions & 0 deletions src/arch/generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#![allow(unused)]
use std::cmp::min;
use std::mem::swap;
use std::thread::JoinHandle;

use rayon::prelude::*;

use crate::distance::Bits;
use crate::encoded_bits::EncodedBits;

pub fn distances<'a>(
query: &'a EncodedBits,
db: &'a [EncodedBits],
) -> impl Iterator<Item = [u16; 31]> + 'a {
const BATCH: usize = 10_000;

// Prepare 31 rotations of query in advance
let rotations: Box<[_]> = (-15..=15).map(|r| query.rotated(r)).collect();

// Iterate over a batch of database entries
db.chunks(BATCH).flat_map(move |chunk| {
let mut results = [[0_u16; 31]; BATCH];

// Parallel computation over batch
results.par_iter_mut().zip(chunk.par_iter()).for_each(
|(result, entry)| {
// Compute dot product for each rotation
for (d, rotation) in result.iter_mut().zip(rotations.iter()) {
*d = rotation.dot(entry);
}
},
);

// Sequentially output results
results.into_iter().take(chunk.len())
})
}

pub fn denominators<'a>(
query: &'a Bits,
db: &'a [Bits],
) -> impl Iterator<Item = [u16; 31]> + 'a {
const BATCH: usize = 10_000;

// Prepare 31 rotations of query in advance
let rotations: Box<[_]> = (-15..=15).map(|r| query.rotated(r)).collect();

// Iterate over a batch of database entries
db.chunks(BATCH).flat_map(move |chunk| {
// Parallel computation over batch
let results = chunk
.par_iter()
.map(|(entry)| {
let mut result = [0_u16; 31];
// Compute dot product for each rotation
for (d, rotation) in result.iter_mut().zip(rotations.iter()) {
*d = rotation.dot(entry);
}
result
})
.collect::<Vec<_>>();

// Sequentially output results
results.into_iter().take(chunk.len())
})
}

//TODO: move this to the benches file
#[cfg(feature = "bench")]
pub mod benches {
use core::hint::black_box;

use criterion::Criterion;
use rand::{thread_rng, Rng};

use super::*;

pub fn group(c: &mut Criterion) {
let mut rng = thread_rng();
let mut g = c.benchmark_group("generic");

g.bench_function("distances 31x1000", |bench| {
let a: EncodedBits = rng.gen();
let b: Box<[EncodedBits]> = (0..1000).map(|_| rng.gen()).collect();
bench.iter(|| {
black_box(distances(black_box(&a), black_box(&b)))
.for_each(|_| {})
})
});

g.bench_function("denominators 31x1000", |bench| {
let a: Bits = rng.gen();
let b: Box<[Bits]> = (0..1000).map(|_| rng.gen()).collect();
bench.iter(|| {
black_box(denominators(black_box(&a), black_box(&b)))
.for_each(|_| {})
})
});
}
}
22 changes: 22 additions & 0 deletions src/arch/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
mod generic; // Optimized generic implementation
mod neon; // Optimized aarch64 NEON implementation
mod reference; // Simple generic implementations

pub use generic::{denominators, distances};

//TODO: move this to the benches file
#[cfg(feature = "bench")]
pub mod benches {
use criterion::Criterion;

use super::*;

pub fn group(c: &mut Criterion) {
reference::benches::group(c);

generic::benches::group(c);

#[cfg(target_feature = "neon")]
neon::benches::group(c);
}
}
20 changes: 20 additions & 0 deletions src/arch/neon.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#![cfg(target_feature = "neon")]
#![allow(unused)]

// Rust + LLVM already generates good NEON code for the generic implementation.

//TODO: move this to the benches file
#[cfg(feature = "bench")]
pub mod benches {
use core::hint::black_box;

use criterion::Criterion;
use rand::{thread_rng, Rng};

use super::*;

pub fn group(c: &mut Criterion) {
let mut g = c.benchmark_group("neon");
let mut rng = thread_rng();
}
}
64 changes: 64 additions & 0 deletions src/arch/reference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#![allow(unused)]

use crate::distance::Bits;
use crate::encoded_bits::EncodedBits;

pub fn distances<'a>(
query: &'a EncodedBits,
db: &'a [EncodedBits],
) -> impl Iterator<Item = [u16; 31]> + 'a {
db.iter().map(|entry| {
let mut result = [0_u16; 31];
for (d, r) in result.iter_mut().zip(-15..=15) {
*d = query.rotated(r).dot(entry);
}
result
})
}

pub fn denominators<'a>(
query: &'a Bits,
db: &'a [Bits],
) -> impl Iterator<Item = [u16; 31]> + 'a {
db.iter().map(|entry| {
let mut result = [0_u16; 31];
for (d, r) in result.iter_mut().zip(-15..=15) {
*d = query.rotated(r).dot(entry);
}
result
})
}

//TODO: move this to the benches file
#[cfg(feature = "bench")]
pub mod benches {
use core::hint::black_box;

use criterion::Criterion;
use rand::{thread_rng, Rng};

use super::*;

pub fn group(c: &mut Criterion) {
let mut rng = thread_rng();
let mut g = c.benchmark_group("reference");

g.bench_function("distances 31x1000", |bench| {
let a: EncodedBits = rng.gen();
let b: Box<[EncodedBits]> = (0..1000).map(|_| rng.gen()).collect();
bench.iter(|| {
black_box(distances(black_box(&a), black_box(&b)))
.for_each(|_| {})
})
});

g.bench_function("denominators 31x1000", |bench| {
let a: Bits = rng.gen();
let b: Box<[Bits]> = (0..1000).map(|_| rng.gen()).collect();
bench.iter(|| {
black_box(denominators(black_box(&a), black_box(&b)))
.for_each(|_| {})
})
});
}
}
Loading

0 comments on commit 1a64018

Please sign in to comment.